#include "Shader.h"
#include "File.h"
#include "FileSystem.h"
#include <cstring>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include "Graphics.h"
#include "Context.h"

// Splits `text` on every occurrence of `delimiter`.
// Empty pieces (between adjacent delimiters, or at either end) are kept only
// when `pushEmpty` is true. An empty `text` always yields an empty vector.
inline std::vector<std::string> splitStringToVector(const std::string& text, char delimiter, bool pushEmpty)
{
	std::vector<std::string> pieces;
	if (text.empty())
		return pieces;

	std::string::size_type from = 0;
	for (;;)
	{
		const std::string::size_type hit = text.find(delimiter, from);
		// When no further delimiter exists, the rest of the string is the last piece.
		const std::string piece = (hit == std::string::npos)
			? text.substr(from)
			: text.substr(from, hit - from);

		if (pushEmpty || !piece.empty())
			pieces.push_back(piece);

		if (hit == std::string::npos)
			break;
		from = hit + 1;
	}
	return pieces;
}

// Returns how many scalar components a GL data-type enum represents,
// e.g. GL_FLOAT_VEC3 -> 3 and GL_FLOAT_MAT4 -> 16. Unknown types give 0.
// Shader::GetActiveAttributes stores this value in AttributeBindInfo::num.
GLsizei GetGLAttributeSize(GLenum type)
{
	switch (type)
	{
	case GL_BOOL:
	case GL_BYTE:
	case GL_UNSIGNED_BYTE:
	case GL_SHORT:
	case GL_UNSIGNED_SHORT:
	case GL_INT:
	case GL_UNSIGNED_INT:
	case GL_FLOAT:
		return 1;
	// Boolean vectors have the same component counts as the other vector types.
	case GL_BOOL_VEC2:
	case GL_INT_VEC2:
	case GL_FLOAT_VEC2:
		return 2;
	case GL_BOOL_VEC3:
	case GL_INT_VEC3:
	case GL_FLOAT_VEC3:
		return 3;
	case GL_BOOL_VEC4:
	case GL_INT_VEC4:
	case GL_FLOAT_VEC4:
	case GL_FLOAT_MAT2:
		return 4;
	case GL_FLOAT_MAT3:
		return 9;
	case GL_FLOAT_MAT4:
		return 16;
	default:
		return 0;
	}
}

// Returns the number of bytes reserved for one element of a GL data type in
// the CPU-side buffers built by Shader::GetActiveAttributes and
// Shader::ComputeUniformInfos. Unknown types return 0; this includes every
// sampler type except GL_SAMPLER_2D.
//
// Boolean types are stored as one GLboolean per component (1, 2, 3, 4 bytes).
// NOTE(review): GL uploads bool uniforms via the glUniform*i / glUniform*f
// entry points; confirm Graphics::SetUniform reads them with this layout.
GLsizei GetGLDataTypeSize(GLenum type)
{
	std::size_t bytes = 0;
	switch (type)
	{
	case GL_SAMPLER_2D:
		// Samplers are set through an integer (the texture unit).
		bytes = sizeof(int);
		break;
	case GL_BOOL:
		bytes = sizeof(GLboolean);
		break;
	case GL_BOOL_VEC2:
		bytes = sizeof(GLboolean) * 2;
		break;
	case GL_BOOL_VEC3:
		// Previously 1 byte; one byte per component, consistent with VEC2/VEC4.
		bytes = sizeof(GLboolean) * 3;
		break;
	case GL_BOOL_VEC4:
		bytes = sizeof(GLboolean) * 4;
		break;
	case GL_BYTE:
	case GL_UNSIGNED_BYTE:
		bytes = sizeof(GLbyte);
		break;
	case GL_SHORT:
	case GL_UNSIGNED_SHORT:
		bytes = sizeof(GLshort);
		break;
	case GL_INT:
	case GL_UNSIGNED_INT:
		bytes = sizeof(GLint);
		break;
	case GL_FLOAT:
		bytes = sizeof(GLfloat);
		break;
	case GL_INT_VEC2:
		bytes = sizeof(GLint) * 2;
		break;
	case GL_FLOAT_VEC2:
		bytes = sizeof(GLfloat) * 2;
		break;
	case GL_INT_VEC3:
		bytes = sizeof(GLint) * 3;
		break;
	case GL_FLOAT_VEC3:
		bytes = sizeof(GLfloat) * 3;
		break;
	case GL_INT_VEC4:
		bytes = sizeof(GLint) * 4;
		break;
	case GL_FLOAT_VEC4:
	case GL_FLOAT_MAT2:
		bytes = sizeof(GLfloat) * 4;
		break;
	case GL_FLOAT_MAT3:
		bytes = sizeof(GLfloat) * 9;
		break;
	case GL_FLOAT_MAT4:
		bytes = sizeof(GLfloat) * 16;
		break;
	default:
		break;
	}
	return static_cast<GLsizei>(bytes);
}

void checkCompileErrors(unsigned int shader, std::string type, const char* code)
{
	int success;
	char infoLog[1024];
	if (type != "PROGRAM")
	{
		glGetShaderiv(shader, GL_COMPILE_STATUS, &success);
		if (!success)
		{
			glGetShaderInfoLog(shader, 1024, NULL, infoLog);
			std::cout << "ERROR::SHADER_COMPILATION_ERROR of type: " << type << "\n" << infoLog << "\n -- --------------------------------------------------- -- " << std::endl;
			std::vector<std::string> out = splitStringToVector(code, '\n', true);
			for (size_t i = 0; i < out.size(); i++)
			{
				std::cout << i + 1 << " " << out[i] << std::endl;
			}
		}
	}
	else
	{
		glGetProgramiv(shader, GL_LINK_STATUS, &success);
		if (!success)
		{
			glGetProgramInfoLog(shader, 1024, NULL, infoLog);
			std::cout << "ERROR::PROGRAM_LINKING_ERROR of type: " << type << "\n" << infoLog << "\n -- --------------------------------------------------- -- " << std::endl;
		}
	}
}

// Drains the GL error queue, printing each error with the given source location.
// Returns the last error code drained, or GL_NO_ERROR (0) when there was none.
GLenum glCheckError_(const char* file, int line)
{
	GLenum lastErrCode = GL_NO_ERROR;
	for (GLenum errorCode = glGetError(); errorCode != GL_NO_ERROR; errorCode = glGetError())
	{
		std::string error;
		switch (errorCode)
		{
		case GL_INVALID_ENUM:                  error = "INVALID_ENUM"; break;
		case GL_INVALID_VALUE:                 error = "INVALID_VALUE"; break;
		case GL_INVALID_OPERATION:             error = "INVALID_OPERATION"; break;
		case GL_STACK_OVERFLOW:                error = "STACK_OVERFLOW"; break;
		case GL_STACK_UNDERFLOW:               error = "STACK_UNDERFLOW"; break;
		case GL_OUT_OF_MEMORY:                 error = "OUT_OF_MEMORY"; break;
		case GL_INVALID_FRAMEBUFFER_OPERATION: error = "INVALID_FRAMEBUFFER_OPERATION"; break;
		// Print the numeric code rather than an empty name for anything unrecognised.
		default:                               error = "UNKNOWN_ERROR (" + std::to_string(errorCode) + ")"; break;
		}
		// Streaming a null char* is undefined behavior, so substitute a placeholder.
		std::cout << error << " | " << (file ? file : "<unknown>") << " (" << line << ")" << std::endl;
		lastErrCode = errorCode;
	}
	return lastErrCode;
}

Shader::Shader(Context* context) :context_(context)
{
}

// Releases the GL program, if GetProgramObject() ever created one.
Shader::~Shader()
{
	if (programID_ <= 0)
		return;

	glDeleteProgram(programID_);
}

// Loads "Shaders/<path>" through the context's file system and expands its
// #include directives into shaderCode_.
// Returns false when path is null/empty, the file is missing, or an include fails.
bool Shader::BeginLoad(const char* path)
{
	if (path == nullptr || *path == '\0')
		return false;

	// Apply the "Shaders/" prefix exactly once, matching how ProcessSource()
	// resolves #include files.
	std::shared_ptr<File> source = context_->GetFileSystem()->GetFile(std::string("Shaders/") + path);
	if (!source)
		return false;

	source->Open();

	// Expand into a scratch string and publish it only on success, so a failed
	// or repeated load never leaves partial or duplicated code in shaderCode_.
	std::string expanded;
	if (!ProcessSource(expanded, *source))
		return false;

	// NOTE(review): if GetProgramObject() already built a program, it is not
	// rebuilt from the new source; confirm whether reloads need to reset it.
	shaderCode_ = std::move(expanded);
	return true;
}

// Appends the contents of `source` to `code`, recursively inlining any
// `#include "name"` line with the file "Shaders/<name>".
// A blank line is appended after each file to separate it from the next.
// Returns false if an include line names no file, or the file cannot be found
// or processed.
bool Shader::ProcessSource(std::string& code, File& source)
{
	const std::string includeDirective = "#include";

	while (!source.IsEof())
	{
		std::string line = source.ReadLine();

		// A line is an include only when "#include" is its first token, so
		// commented-out directives such as "// #include ..." stay as plain text.
		const std::string::size_type firstToken = line.find_first_not_of(" \t");
		const bool isInclude = firstToken != std::string::npos &&
			line.compare(firstToken, includeDirective.size(), includeDirective) == 0;

		if (!isInclude)
		{
			code += line;
			code += "\n";
			continue;
		}

		// Everything after the directive is the (optionally quoted) file name.
		// Slicing at the directive's end is always in range, even for a bare "#include".
		std::string includeFileName = line.substr(firstToken + includeDirective.size());
		string_replase(includeFileName, "\"", " ");
		includeFileName = string_trimmed(includeFileName);
		if (includeFileName.empty())
			return false;

		auto includeFile = context_->GetFileSystem()->GetFile("Shaders/" + includeFileName);
		if (!includeFile)
			return false;

		includeFile->Open();

		// Inline the included file recursively.
		// NOTE(review): there is no guard against include cycles; a file that
		// includes itself (directly or indirectly) recurses without bound.
		if (!ProcessSource(code, *includeFile))
			return false;
	}

	// Finally insert an empty line to mark the space between files
	code += "\n";

	return true;
}

// Wraps the first function whose text starts with `signature` in a /* ... */
// block comment, from the signature to its matching closing brace.
// Does nothing if `signature` does not occur in `code`.
// Limitations: braces inside comments or strings are counted, and a "*/" inside
// the function body ends the inserted comment early.
void Shader::CommentOutFunction(std::string& code, const std::string& signature)
{
	// Keep the full size_type: narrowing find()'s result to 32 bits breaks the
	// npos comparison on 64-bit targets, and insert() then throws.
	const std::string::size_type startPos = code.find(signature);
	if (startPos == std::string::npos)
		return;

	code.insert(startPos, "/*");

	unsigned braceLevel = 0;
	// Scan from just past the signature, which the "/*" shifted by two characters.
	for (std::string::size_type i = startPos + 2 + signature.length(); i < code.length(); ++i)
	{
		if (code[i] == '{')
		{
			++braceLevel;
		}
		else if (code[i] == '}' && braceLevel > 0)
		{
			// Ignore a stray '}' before the body opens instead of wrapping the counter.
			--braceLevel;
			if (braceLevel == 0)
			{
				code.insert(i + 1, "*/");
				return;
			}
		}
	}
}

unsigned Shader::GetProgramObject()
{
	if (programID_ > 0)
		return programID_;

	// Comment out the unneeded shader function
	std::string strDef = "";
	for (size_t i = 0; i < defines_.size(); i++)
	{
		strDef += "#define " + defines_[i];
		strDef += '\n';
	}
	std::string vsSourceCode_ = "#version 330 core\n #define COMPILEVS\n" + strDef + shaderCode_;
	std::string psSourceCode_ = "#version 330 core\n #define COMPILEPS\n" + strDef + shaderCode_;
	CommentOutFunction(vsSourceCode_, "void PS(");
	CommentOutFunction(psSourceCode_, "void VS(");

	string_replase(vsSourceCode_, "void VS(", "void main(");
	string_replase(psSourceCode_, "void PS(", "void main(");

	std::string vertexCode;
	std::string fragmentCode;

	const char* vShaderCode = vsSourceCode_.c_str();
	const char* fShaderCode = psSourceCode_.c_str();
	// 2. compile shaders
	unsigned int vertex, fragment;
	// vertex shader
	vertex = glCreateShader(GL_VERTEX_SHADER);
	glShaderSource(vertex, 1, &vShaderCode, NULL);
	glCompileShader(vertex);
	checkCompileErrors(vertex, "VERTEX", vShaderCode);
	// fragment Shader
	fragment = glCreateShader(GL_FRAGMENT_SHADER);
	glShaderSource(fragment, 1, &fShaderCode, NULL);
	glCompileShader(fragment);
	checkCompileErrors(fragment, "FRAGMENT", fShaderCode);
	// shader Program
	programID_ = glCreateProgram();
	glAttachShader(programID_, vertex);
	glAttachShader(programID_, fragment);
	glLinkProgram(programID_);
	checkCompileErrors(programID_, "PROGRAM", "");
	// delete the shaders as they're linked into our program now and no longer necessary
	glDeleteShader(vertex);
	glDeleteShader(fragment);

	return programID_;
}

void Shader::GetActiveAttributes(unsigned programID, std::unordered_map<std::string, AttributeBindInfo>& attributeBindInfo)
{
	std::unordered_map<std::string, AttributeBindInfo>& attributes = attributeBindInfo;

	if (!programID) return;

	GLint numOfActiveAttributes = 0;
	glGetProgramiv(programID, GL_ACTIVE_ATTRIBUTES, &numOfActiveAttributes);


	if (numOfActiveAttributes <= 0)
		return;

	attributes.reserve(numOfActiveAttributes);

	std::string attrName;
	attrName.resize(MAX_ATTRIBUTE_NAME_LENGTH + 1);

	GLint attrNameLen = 0;
	GLenum attrType;
	GLint attrSize;
	AttributeBindInfo info;

	for (int i = 0; i < numOfActiveAttributes; i++)
	{
		glGetActiveAttrib(programID, i, MAX_ATTRIBUTE_NAME_LENGTH, &attrNameLen, &attrSize, &attrType, &attrName[0]);
		info.attributeName = std::string(attrName.data(), attrName.data() + attrNameLen);
		info.location = glGetAttribLocation(programID, info.attributeName.c_str());
		info.type = attrType;
		info.size = GetGLDataTypeSize(attrType) * attrSize;
		info.num = GetGLAttributeSize(attrType);
		attributes[info.attributeName] = info;
	}
}

void Shader::ComputeUniformInfos(unsigned programID, unsigned& buffsize, std::unordered_map<std::string, UniformInfo>& activeUniformInfos)
{
	if (!programID)
		return;

	GLint numOfUniforms = 0;
	glGetProgramiv(programID, GL_ACTIVE_UNIFORMS, &numOfUniforms);
	if (!numOfUniforms)
		return;

#define MAX_UNIFORM_NAME_LENGTH 256

	UniformInfo uniform;
	GLint length = 0;
	activeUniformInfos.clear();
	GLchar* uniformName = (GLchar*)malloc(MAX_UNIFORM_NAME_LENGTH + 1);
	buffsize = 0;
	for (int i = 0; i < numOfUniforms; ++i)
	{
		glGetActiveUniform(programID, i, MAX_UNIFORM_NAME_LENGTH, &length, &uniform.count, &uniform.type, uniformName);
		uniformName[length] = '\0';

		if (length > 3)
		{
			char* c = strrchr(uniformName, '[');
			if (c)
			{
				*c = '\0';
				uniform.isArray = true;
			}
		}
		uniform.location = glGetUniformLocation(programID, uniformName);
		uniform.size = GetGLDataTypeSize(uniform.type);
		uniform.bufferOffset = (uniform.size == 0) ? 0 : buffsize;
		activeUniformInfos[uniformName] = uniform;
		buffsize += uniform.size * uniform.count;
	}
	free(uniformName);
}

ProgramState::ProgramState(Context* context) :context_(context)
{
}

// Frees the CPU-side uniform buffer (delete[] on a null pointer is a no-op).
ProgramState::~ProgramState()
{
	delete[] uniformBuffer_;
}

// Binds this state to `shader`: fetches (building if needed) its GL program,
// reflects its attributes and uniforms, and allocates a zeroed uniform buffer
// of the reflected size. Ignored for a null shader or the one already bound.
void ProgramState::InitWithShader(Shader* shader)
{
	if (shader == nullptr || shader == shader_)
		return;

	shader_ = shader;
	attributeBindInfo_.clear();
	activeUniformInfos_.clear();
	programID_ = shader->GetProgramObject();

	// Drop the storage that belonged to the previous shader.
	delete[] uniformBuffer_;
	uniformBuffer_ = nullptr;
	totalBufferSize_ = 0;

	shader->GetActiveAttributes(programID_, attributeBindInfo_);
	shader->ComputeUniformInfos(programID_, totalBufferSize_, activeUniformInfos_);

	// The trailing () value-initialises the array, so every byte starts at zero.
	uniformBuffer_ = new char[totalBufferSize_]();
}

// Copies size * count bytes from `data` into the CPU-side slot of uniform `name`
// and marks it dirty when the bytes changed; Apply() later uploads dirty uniforms.
// Null arguments and unknown uniform names are ignored.
void ProgramState::SetUniform(const char* name, const void* data)
{
	if (!name || !data)
		return;

	auto found = activeUniformInfos_.find(name);
	if (found == activeUniformInfos_.end())
		return;

	auto& uniform = found->second;
	unsigned offset = uniform.bufferOffset;
	unsigned size = uniform.size * uniform.count;
	// Only set the flag here, never clear it: a change from an earlier call that
	// Apply() has not uploaded yet must stay pending even if this call repeats it.
	if (memcmp(uniformBuffer_ + offset, data, size) != 0)
	{
		memcpy(uniformBuffer_ + offset, data, size);
		uniform.isDirty = true;
	}
}

// Uploads uniforms through Graphics::SetUniform: every dirty uniform, or all of
// them when `force` is true. Uniforms with no CPU storage (size <= 0) are skipped.
// Every uploaded uniform ends up clean.
void ProgramState::Apply(bool force)
{
	for (auto& entry : activeUniformInfos_)
	{
		auto& info = entry.second;
		if (info.size <= 0)
			continue;
		if (!force && !info.isDirty)
			continue;

		context_->GetGraphics()->SetUniform(info.isArray,
			info.location,
			static_cast<int>(info.count),
			info.type,
			static_cast<void*>(uniformBuffer_ + info.bufferOffset));

		info.isDirty = false;
	}
}

